from dataclasses import dataclass, field
from functools import partial

from opentelemetry import metrics
from opentelemetry.metrics import Counter, Histogram
from opentelemetry.metrics._internal import Gauge

from letta.helpers.singleton import singleton
from letta.otel.metrics import get_letta_meter


@singleton
@dataclass(frozen=True)
class MetricRegistry:
    """Registry of all application metrics.

    Each instrument is created lazily on first access through its property and
    cached by metric name, so repeated accesses return the same object.

    Metrics are composed of the following:
        - name
        - description
        - unit: UCUM unit of the metric (e.g. 'By' for bytes, 'ms' for milliseconds, '1' for count)
        - bucket_bounds (list[float] | None): the explicit bucket bounds for histogram metrics
          (none of the histograms defined in this class pass explicit bounds)

        and instruments are of types Counter, Histogram, and Gauge

    The relationship between the various models is as follows:
        project_id -N:1-> base_template_id -N:1-> template_id -N:1-> agent_id
        agent_id -1:1+-> model_name
        agent_id -1:N -> tool_name
    """

    Instrument = Counter | Histogram | Gauge
    # frozen=True only prevents rebinding the attribute; the dict itself is
    # still mutated in place to cache instruments.
    _metrics: dict[str, Instrument] = field(default_factory=dict, init=False)
    _meter: metrics.Meter = field(init=False)

    def __post_init__(self):
        # A frozen dataclass rejects normal attribute assignment, so the meter
        # is installed through object.__setattr__.
        object.__setattr__(self, "_meter", get_letta_meter())

    def _get_or_create_metric(self, name: str, factory):
        """Return the instrument cached under ``name``, creating it on first use.

        Args:
            name: Cache key for the instrument.
            factory: Zero-argument callable that builds the instrument; it is
                only invoked when ``name`` is not cached yet.
        """
        if name not in self._metrics:
            self._metrics[name] = factory()
        return self._metrics[name]

    # The three helpers below use ``name`` both as the cache key and as the
    # instrument name, so the two can never drift apart.

    def _counter(self, name: str, description: str, unit: str) -> Counter:
        """Return the lazily created Counter registered under ``name``."""
        return self._get_or_create_metric(
            name,
            partial(self._meter.create_counter, name=name, description=description, unit=unit),
        )

    def _histogram(self, name: str, description: str, unit: str) -> Histogram:
        """Return the lazily created Histogram registered under ``name``."""
        return self._get_or_create_metric(
            name,
            partial(self._meter.create_histogram, name=name, description=description, unit=unit),
        )

    def _gauge(self, name: str, description: str, unit: str) -> Gauge:
        """Return the lazily created Gauge registered under ``name``."""
        return self._get_or_create_metric(
            name,
            partial(self._meter.create_gauge, name=name, description=description, unit=unit),
        )

    # (includes base attributes: project, template_base, template, agent)
    @property
    def user_message_counter(self) -> Counter:
        return self._counter(
            "count_user_message",
            description="Counts the number of messages sent by the user",
            unit="1",
        )

    # (includes tool_name, tool_execution_success, & step_id on failure)
    @property
    def tool_execution_counter(self) -> Counter:
        return self._counter(
            "count_tool_execution",
            description="Counts the number of tools executed.",
            unit="1",
        )

    # project_id + model
    @property
    def ttft_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_ttft_ms",
            description="Histogram for the Time to First Token (ms)",
            unit="ms",
        )

    # (includes model name)
    @property
    def llm_execution_time_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_llm_execution_time_ms",
            description="Histogram for LLM execution time (ms)",
            unit="ms",
        )

    # (includes tool name)
    @property
    def tool_execution_time_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_tool_execution_time_ms",
            description="Histogram for tool execution time (ms)",
            unit="ms",
        )

    @property
    def step_execution_time_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_step_execution_time_ms",
            description="Histogram for step execution time (ms)",
            unit="ms",
        )

    # TODO (cliandy): instrument this
    @property
    def message_cost(self) -> Histogram:
        return self._histogram(
            "hist_message_cost_usd",
            description="Histogram for cost of messages (usd) per step",
            unit="usd",
        )

    # (includes model name)
    @property
    def message_output_tokens(self) -> Histogram:
        return self._histogram(
            "hist_message_output_tokens",
            description="Histogram for output tokens generated by LLM per step",
            unit="1",
        )

    # (includes endpoint_path, method, status_code)
    @property
    def endpoint_e2e_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_endpoint_e2e_ms",
            description="Histogram for endpoint e2e time (ms)",
            unit="ms",
        )

    # (includes endpoint_path, method, status_code)
    @property
    def endpoint_request_counter(self) -> Counter:
        return self._counter(
            "count_endpoint_requests",
            description="Counts the number of endpoint requests",
            unit="1",
        )

    @property
    def file_process_bytes_histogram(self) -> Histogram:
        return self._histogram(
            "hist_file_process_bytes",
            description="Histogram for file process in bytes",
            unit="By",
        )

    # Database connection pool metrics
    # (includes engine_name)
    @property
    def db_pool_connections_total_gauge(self) -> Gauge:
        return self._gauge(
            "gauge_db_pool_connections_total",
            description="Total number of connections in the database pool",
            unit="1",
        )

    # (includes engine_name)
    @property
    def db_pool_connections_checked_out_gauge(self) -> Gauge:
        return self._gauge(
            "gauge_db_pool_connections_checked_out",
            description="Number of connections currently checked out from the pool",
            unit="1",
        )

    # (includes engine_name)
    @property
    def db_pool_connections_available_gauge(self) -> Gauge:
        return self._gauge(
            "gauge_db_pool_connections_available",
            description="Number of available connections in the pool",
            unit="1",
        )

    # (includes engine_name)
    @property
    def db_pool_connections_overflow_gauge(self) -> Gauge:
        return self._gauge(
            "gauge_db_pool_connections_overflow",
            description="Number of overflow connections in the pool",
            unit="1",
        )

    # (includes engine_name)
    @property
    def db_pool_connection_duration_ms_histogram(self) -> Histogram:
        return self._histogram(
            "hist_db_pool_connection_duration_ms",
            description="Duration of database connection usage in milliseconds",
            unit="ms",
        )

    # (includes engine_name, event)
    @property
    def db_pool_connection_events_counter(self) -> Counter:
        return self._counter(
            "count_db_pool_connection_events",
            description="Count of database connection pool events (connect, checkout, checkin, invalidate)",
            unit="1",
        )

    # (includes engine_name, exception_type)
    @property
    def db_pool_connection_errors_counter(self) -> Counter:
        return self._counter(
            "count_db_pool_connection_errors",
            description="Count of database connection pool errors",
            unit="1",
        )
